package com.tencent.angel.sona.ml

/**
  * @author xiongjun
  * @date 2019/11/28 10:43
  * @description A custom Spark ML pipeline that starts and stops the Angel
  *              parameter servers around each estimator's fit step.
  * @reviewer
  */

import java.{util => ju}

import com.tencent.angel.sona.core.DriverContext
import com.tencent.angel.sona.ml.{Estimator, Pipeline, PipelineStage, Transformer}
import com.tencent.angel.sona.ml.param.{Param, ParamMap, Params}
import com.tencent.angel.sona.ml.util._

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

/**
  * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
  * of which is either an [[Estimator]] or a [[Transformer]]. When `Pipeline.fit` is called, the
  * stages are executed in order. If a stage is an [[Estimator]], its `Estimator.fit` method will
  * be called on the input dataset to fit a model. Then the model, which is a transformer, will be
  * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]],
  * its `Transformer.transform` method will be called to produce the dataset for the next stage.
  * The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of fitted models and
  * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
  * an identity transformer.
  */

class CusPipeline(
                override val uid: String) extends Estimator[CusPipelineModel] with MLWritable {


  def this() = this(Identifiable.randomUID("pipeline"))

  /**
    * Param for pipeline stages.
    *
    * @group param
    */

  val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")

  /** @group setParam */

  def setStages(value: Array[_ <: PipelineStage]): this.type = {
    set(stages, value.asInstanceOf[Array[PipelineStage]])
    this
  }

  // Below, we clone stages so that modifications to the list of stages will not change
  // the Param value in the Pipeline.
  /** @group getParam */

  def getStages: Array[PipelineStage] = $(stages).clone()

  /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
    * [[Estimator]], its `Estimator.fit` method will be called on the input dataset to fit a model.
    * Then the model, which is a transformer, will be used to transform the dataset as the input to
    * the next stage. If a stage is a [[Transformer]], its `Transformer.transform` method will be
    * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an
    * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the
    * pipeline stages. If there are no stages, the output model acts as an identity transformer.
    *
    * Unlike Spark's Pipeline, each estimator fit is bracketed by starting and stopping the
    * Angel parameter servers via [[DriverContext]].
    *
    * @param dataset input dataset
    * @return fitted pipeline
    */

  override def fit(dataset: Dataset[_]): CusPipelineModel = {
    transformSchema(dataset.schema, logging = true)
    val theStages = $(stages)
    // Search for the last estimator. Stages after it never need the dataset during fit,
    // so we can skip transforming past that point.
    var indexOfLastEstimator = -1
    theStages.zipWithIndex.foreach { case (stage, index) =>
      stage match {
        case _: Estimator[_] =>
          indexOfLastEstimator = index
        case _ =>
      }
    }
    var curDataset = dataset
    val transformers = ListBuffer.empty[Transformer]
    theStages.zipWithIndex.foreach { case (stage, index) =>
      if (index <= indexOfLastEstimator) {
        val transformer = stage match {
          case estimator: Estimator[_] =>
            // The Angel PS agents must be running while an estimator fits. Stop them in a
            // finally block so a failing fit() cannot leak parameter-server resources.
            val driverCtx = DriverContext.get(dataset.sparkSession.sparkContext.getConf)
            driverCtx.startAngelAndPSAgent()
            try {
              estimator.fit(curDataset)
            } finally {
              driverCtx.stopAngelAndPSAgent()
            }
          case t: Transformer =>
            t
          case _ =>
            throw new IllegalArgumentException(
              s"Does not support stage $stage of type ${stage.getClass}")
        }
        // Only transform when a later estimator still needs the data.
        if (index < indexOfLastEstimator) {
          curDataset = transformer.transform(curDataset)
        }
        transformers += transformer
      } else {
        // Stages after the last estimator are already transformers; no fitting needed.
        transformers += stage.asInstanceOf[Transformer]
      }
    }

    new CusPipelineModel(uid, transformers.toArray).setParent(this)
  }

  /** Creates a copy of this pipeline with all stages copied under the extra param map. */
  override def copy(extra: ParamMap): CusPipeline = {
    val map = extractParamMap(extra)
    val newStages = map(stages).map(_.copy(extra))
    new CusPipeline(uid).setStages(newStages)
  }

  /** Derives the output schema by threading the input schema through every stage in order. */
  override def transformSchema(schema: StructType): StructType = {
    val theStages = $(stages)
    require(theStages.toSet.size == theStages.length,
      "Cannot have duplicate components in a pipeline.")
    theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
  }


  override def write: MLWriter = new CusPipeline.CusPipelineWriter(this)
}


object CusPipeline extends MLReadable[CusPipeline] {


  override def read: MLReader[CusPipeline] = new CusPipelineReader


  override def load(path: String): CusPipeline = super.load(path)

  /** Writer that persists a [[CusPipeline]] (metadata plus every stage). */
  private[CusPipeline] class CusPipelineWriter(instance: CusPipeline) extends MLWriter {

    SharedReadWrite.validateStages(instance.getStages)

    override protected def saveImpl(path: String): Unit =
      SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
  }

  /** Reader that restores a [[CusPipeline]] saved by [[CusPipelineWriter]]. */
  private class CusPipelineReader extends MLReader[CusPipeline] {

    /** Checked against metadata when loading model.
      * Must match the class written by saveMetadata (this class, not Spark's Pipeline),
      * otherwise the load-time class-name check fails for our own saved pipelines. */
    private val className = classOf[CusPipeline].getName

    override def load(path: String): CusPipeline = {
      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
      new CusPipeline(uid).setStages(stages)
    }
  }

  /**
    * Methods for `MLReader` and `MLWriter` shared between [[Pipeline]] and [[PipelineModel]]
    */
  private[sona] object SharedReadWrite {

    import org.json4s.JsonDSL._

    /** Check that all stages are Writable; fail fast before any partial save happens. */
    def validateStages(stages: Array[PipelineStage]): Unit = {
      stages.foreach {
        case _: MLWritable => // good
        case other =>
          throw new UnsupportedOperationException("Pipeline write will fail on this Pipeline" +
            s" because it contains a stage which does not implement Writable. Non-Writable stage:" +
            s" ${other.uid} of type ${other.getClass}")
      }
    }

    /**
      * Save metadata and stages for a [[Pipeline]] or [[PipelineModel]]
      *  - save metadata to path/metadata
      *  - save stages to stages/IDX_UID
      */
    def saveImpl(
                  instance: Params,
                  stages: Array[PipelineStage],
                  sc: SparkContext,
                  path: String): Unit = {
      val stageUids = stages.map(_.uid)
      val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))

      // Save stages, each under an index-prefixed directory so load order is deterministic.
      val stagesDir = new Path(path, "stages").toString
      stages.zipWithIndex.foreach { case (stage, idx) =>
        stage.asInstanceOf[MLWritable].write.save(
          getStagePath(stage.uid, idx, stages.length, stagesDir))
      }
    }

    /**
      * Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
      *
      * @return (UID, list of stages)
      */
    def load(
              expectedClassName: String,
              sc: SparkContext,
              path: String): (String, Array[PipelineStage]) = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)

      implicit val format = DefaultFormats
      val stagesDir = new Path(path, "stages").toString
      val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
      val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
        val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
        DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
      }
      (metadata.uid, stages)
    }

    /** Get path for saving the given stage, zero-padded so paths sort in stage order. */
    def getStagePath(stageUid: String, stageIdx: Int, numStages: Int, stagesDir: String): String = {
      val stageIdxDigits = numStages.toString.length
      val idxFormat = s"%0${stageIdxDigits}d"
      val stageDir = idxFormat.format(stageIdx) + "_" + stageUid
      new Path(stagesDir, stageDir).toString
    }
  }

}

/**
  * Represents a fitted pipeline.
  */

class CusPipelineModel private[angel](
                                    override val uid: String,
                                    val stages: Array[Transformer])
  extends Model[CusPipelineModel] with MLWritable with Logging {

  /** A Java/Python-friendly auxiliary constructor. */
  private[sona] def this(uid: String, stages: ju.List[Transformer]) = {
    this(uid, stages.asScala.toArray)
  }

  /**
    * Applies every stage's transform in sequence to the input dataset.
    * With no stages this is the identity transformation (modulo `toDF`).
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    stages.foldLeft(dataset.toDF) { (df, stage) =>
      stage.transform(df)
    }
  }

  /** Threads the input schema through each stage's `transformSchema` in order. */
  override def transformSchema(schema: StructType): StructType = {
    stages.foldLeft(schema) { (currentSchema, stage) =>
      stage.transformSchema(currentSchema)
    }
  }

  /** Copies this model, copying each contained stage with the extra param map. */
  override def copy(extra: ParamMap): CusPipelineModel = {
    val copiedStages = stages.map(_.copy(extra))
    new CusPipelineModel(uid, copiedStages).setParent(parent)
  }


  override def write: MLWriter = new CusPipelineModel.CusPipelineModelWriter(this)
}


object CusPipelineModel extends MLReadable[CusPipelineModel] {

  import CusPipeline.SharedReadWrite


  override def read: MLReader[CusPipelineModel] = new CusPipelineModelReader


  override def load(path: String): CusPipelineModel = super.load(path)

  /** Writer that persists a fitted [[CusPipelineModel]] (metadata plus every transformer). */
  private[CusPipelineModel] class CusPipelineModelWriter(instance: CusPipelineModel) extends MLWriter {

    SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])

    override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
      instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
  }

  /** Reader that restores a [[CusPipelineModel]] saved by [[CusPipelineModelWriter]]. */
  private class CusPipelineModelReader extends MLReader[CusPipelineModel] {

    /** Checked against metadata when loading model */
    private val className = classOf[CusPipelineModel].getName

    override def load(path: String): CusPipelineModel = {
      val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
      val transformers = stages map {
        case stage: Transformer => stage
        // A fitted model may only contain transformers; name this class in the error so the
        // diagnostic points at the real reader (it previously said "PipelineModel.read").
        case other => throw new RuntimeException(s"CusPipelineModel.read loaded a stage but found it" +
          s" was not a Transformer.  Bad stage ${other.uid} of type ${other.getClass}")
      }
      new CusPipelineModel(uid, transformers)
    }
  }

}

